# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from os import path as osp
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import mmcv
import numpy as np
import torch
import torch.nn as nn
from mmcv.image import tensor2imgs
from mmcv.parallel import collate, scatter
from mmdet3d.core import Box3DMode
from mmdet3d.core.bbox import get_box_type
from mmdet3d.datasets.pipelines import Compose
from mmdet3d.models import (Base3DDetector, Base3DSegmentor,
                            SingleStageMono3DDetector)
from torch.utils.data import Dataset

from mmdeploy.codebase.base import BaseTask
from mmdeploy.codebase.mmdet3d.deploy.mmdetection3d import MMDET3D_TASK
from mmdeploy.utils import Task, get_root_logger
from mmdeploy.utils.config_utils import is_dynamic_shape


@MMDET3D_TASK.register_module(Task.MONOCULAR_DETECTION.value)
class MonocularDetection(BaseTask):
    """Monocular 3D object detection task for mmdeploy.

    Wraps an mmdet3d monocular detector so it can be exported to and run on
    inference backends: builds backend / pytorch models, converts a single
    image (plus camera intrinsics looked up in the annotation file named by
    ``deploy_cfg.codebase_config.ann_file``) into model inputs, and provides
    visualization and evaluation helpers.
    """

    def __init__(self, model_cfg: mmcv.Config, deploy_cfg: mmcv.Config,
                 device: str):
        super().__init__(model_cfg, deploy_cfg, device)

    def init_backend_model(self,
                           model_files: Sequence[str] = None,
                           **kwargs) -> torch.nn.Module:
        """Initialize backend model.

        Args:
            model_files (Sequence[str]): Input model files.

        Returns:
            nn.Module: An initialized backend model.
        """
        from .monocular_detection_model import build_monocular_detection_model
        model = build_monocular_detection_model(
            model_files, self.model_cfg, self.deploy_cfg, device=self.device)
        return model

    def init_pytorch_model(self,
                           model_checkpoint: Optional[str] = None,
                           cfg_options: Optional[Dict] = None,
                           **kwargs) -> torch.nn.Module:
        """Initialize torch model.

        Args:
            model_checkpoint (str): The checkpoint file of torch model,
                defaults to `None`.
            cfg_options (dict): Optional config key-pair parameters.

        Returns:
            nn.Module: An initialized torch model generated by other OpenMMLab
                codebases, switched to eval mode.
        """
        from mmdet3d.apis import init_model
        model = init_model(self.model_cfg, model_checkpoint, self.device)
        return model.eval()

    def create_input(self,
                     imgs: Union[str, np.ndarray],
                     input_shape: Optional[Sequence[int]] = None,
                     pipeline_updater: Optional[Callable] = None, **kwargs) \
            -> Tuple[Dict, torch.Tensor]:
        """Create input for detector.

        Args:
            imgs (str | np.ndarray): Path of the input image. Only a file
                path works here because the camera intrinsics are looked up
                in the annotation file by file name.
            input_shape (Sequence[int] | None): Input shape of image in
                (width, height) format, defaults to `None`.
            pipeline_updater (function | None): A function to get a new
                pipeline.

        Returns:
            tuple: (data, inputs), meta information for the input image and
                the tuple of backend input tensors
                (img, cam2img, cam2img_inverse).

        Raises:
            ValueError: If ``imgs`` is not found in the annotation file
                declared by ``deploy_cfg.codebase_config.ann_file``.
        """
        dynamic_flag = is_dynamic_shape(self.deploy_cfg)
        cfg = self.model_cfg
        # Drop pad_to_square when static shape. Because static shape should
        # ensure the shape before input image.
        if not dynamic_flag:
            transform = cfg.data.test.pipeline[1]
            if 'transforms' in transform:
                transform_list = transform['transforms']
                for i, step in enumerate(transform_list):
                    if step['type'] == 'Pad' and 'pad_to_square' in step \
                       and step['pad_to_square']:
                        transform_list.pop(i)
                        break
        # build the data pipeline
        test_pipeline = deepcopy(cfg.data.test.pipeline)
        test_pipeline = Compose(test_pipeline)
        box_type_3d, box_mode_3d = get_box_type(cfg.data.test.box_type_3d)
        # load annotation info
        ann_file = self.deploy_cfg.codebase_config.ann_file
        data_infos = mmcv.load(ann_file)
        # find the info corresponding to this image; fail fast with a clear
        # message instead of an unbound-name error when it is not annotated
        img_info = None
        for info in data_infos['images']:
            if osp.basename(info['file_name']) == osp.basename(imgs):
                img_info = info
                break
        if img_info is None:
            raise ValueError(
                f'{osp.basename(imgs)} is not found in {ann_file}')
        data = dict(
            img_prefix=osp.dirname(imgs),
            img_info=dict(filename=osp.basename(imgs)),
            box_type_3d=box_type_3d,
            box_mode_3d=box_mode_3d,
            img_fields=[],
            bbox3d_fields=[],
            pts_mask_fields=[],
            pts_seg_fields=[],
            bbox_fields=[],
            mask_fields=[],
            seg_fields=[])

        # camera points to image conversion
        if box_mode_3d == Box3DMode.CAM:
            data['img_info'].update(
                dict(cam_intrinsic=img_info['cam_intrinsic']))

        data = test_pipeline(data)

        data = collate([data], samples_per_gpu=1)
        # unwrap DataContainers produced by the pipeline
        data['img_metas'] = [
            img_metas.data[0] for img_metas in data['img_metas']
        ]
        data['img'] = [img.data[0] for img in data['img']]
        # camera intrinsics (and their inverse) are extra backend inputs
        data['cam2img'] = [torch.tensor(data['img_metas'][0][0]['cam2img'])]
        data['cam2img_inverse'] = [torch.inverse(data['cam2img'][0])]
        if self.device != 'cpu':
            # scatter to specified GPU
            data = scatter(data, [self.device])[0]

        return data, tuple(data['img'] + data['cam2img'] +
                           data['cam2img_inverse'])

    def visualize(self,
                  model: torch.nn.Module,
                  image: str,
                  result: list,
                  output_file: str,
                  window_name: str,
                  show_result: bool = False,
                  score_thr: float = 0.3):
        """Visualize predictions of a model.

        Args:
            model (nn.Module): Input model.
            image (str): Image file to draw predictions on.
            result (list): A list of predictions.
            output_file (str): Output file to save result.
            window_name (str): The name of visualization window. Defaults to
                an empty string.
            show_result (bool): Whether to show result in windows, defaults
                to `False`.
            score_thr (float): The score threshold to display the bbox.
                Defaults to 0.3.
        """
        # show_result_meshlab expects an output directory, so strip the
        # extension; use splitext instead of split('.') so directories
        # containing dots (e.g. 'v1.0/res.jpg') are not truncated
        if output_file.endswith('.jpg'):
            output_file = osp.splitext(output_file)[0]
        from mmdet3d.apis import show_result_meshlab
        data, _ = self.create_input(image)
        show_result_meshlab(
            data,
            result,
            output_file,
            score_thr,
            show=show_result,
            snapshot=not show_result,
            task='mono-det')

    @staticmethod
    def run_inference(model: nn.Module,
                      model_inputs: Dict[str, torch.Tensor]) -> List:
        """Run inference once for a object detection model of mmdet3d.

        Args:
            model (nn.Module): Input model.
            model_inputs (dict): A dict containing model inputs tensor and
                meta info.

        Returns:
            list: The predictions of model inference.
        """
        return [
            model(
                model_inputs['img'],
                model_inputs['img_metas'],
                return_loss=False,
                rescale=True)
        ]

    @staticmethod
    def evaluate_outputs(model_cfg,
                         outputs: Sequence,
                         dataset: Dataset,
                         metrics: Optional[str] = None,
                         out: Optional[str] = None,
                         metric_options: Optional[dict] = None,
                         format_only: bool = False,
                         log_file: Optional[str] = None):
        """Perform post-processing / evaluation of model predictions.

        Args:
            model_cfg (mmcv.Config): The model config, read for the
                ``evaluation`` section.
            outputs (Sequence): A list of predictions of the model.
            dataset (Dataset): Input dataset to run evaluation on.
            metrics (str): Evaluation metric name, defaults to `None`.
            out (str): Output file to save the predictions, defaults
                to `None`.
            metric_options (dict): Custom options for evaluation, defaults
                to `None`.
            format_only (bool): Format the output results without performing
                evaluation, defaults to `False`.
            log_file (str): Log file name to record evaluation results,
                defaults to `None`.
        """
        if out:
            logger = get_root_logger()
            logger.info(f'\nwriting results to {out}')
            mmcv.dump(outputs, out)
        kwargs = {} if metric_options is None else metric_options
        if format_only:
            dataset.format_results(outputs, **kwargs)
        if metrics:
            eval_kwargs = model_cfg.get('evaluation', {}).copy()
            # hard-code way to remove EvalHook args
            for key in [
                    'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best',
                    'rule'
            ]:
                eval_kwargs.pop(key, None)
            eval_kwargs.update(dict(metric=metrics, **kwargs))
            print(dataset.evaluate(outputs, **eval_kwargs))

    def get_model_name(self) -> str:
        """Get the model name.

        Return:
            str: the name of the model.
        """
        raise NotImplementedError

    def get_tensor_from_input(self, input_data: Dict[str, Any],
                              **kwargs) -> torch.Tensor:
        """Get input tensor from input data.

        Args:
            input_data (dict): Input data containing meta info and image
                tensor.

        Returns:
            torch.Tensor: An image in `Tensor`.
        """
        raise NotImplementedError

    @staticmethod
    def get_partition_cfg(partition_type: str, **kwargs) -> Dict:
        """Get a certain partition config for mmdet.

        Args:
            partition_type (str): A string specifying partition type.

        Returns:
            dict: A dictionary of partition config.
        """
        raise NotImplementedError

    def get_postprocess(self) -> Dict:
        """Get the postprocess information for SDK.

        Return:
            dict: Composed of the postprocess information.
        """
        raise NotImplementedError

    def get_preprocess(self) -> Dict:
        """Get the preprocess information for SDK.

        Return:
            dict: Composed of the preprocess information.
        """
        raise NotImplementedError

    def single_gpu_test(self,
                        model,
                        data_loader,
                        show=False,
                        out_dir=None,
                        show_score_thr=0.3):
        """Test model with single gpu.

        This method tests model with single gpu and gives the 'show' option.
        By setting ``show=True``, it saves the visualization results under
        ``out_dir``.

        Args:
            model (nn.Module): Model to be tested.
            data_loader (nn.Dataloader): Pytorch data loader.
            show (bool, optional): Whether to save viualization results.
                Default: True.
            out_dir (str, optional): The path to save visualization results.
                Default: None.
            show_score_thr (float): The score threshold for visulization
                Default is 0.3

        Returns:
            list[dict]: The prediction results.
        """
        results = []
        dataset = data_loader.dataset
        prog_bar = mmcv.ProgressBar(len(dataset))
        for data in data_loader:
            with torch.no_grad():
                result = model(return_loss=False, rescale=True, **data)

            if show:
                # Visualize the results of MMDetection3D model
                # 'show_results' is MMdetection3D visualization API
                models_3d = (Base3DDetector, Base3DSegmentor,
                             SingleStageMono3DDetector)
                if isinstance(model.module, models_3d):
                    model.module.show_results(
                        data,
                        result,
                        out_dir=out_dir,
                        show=show,
                        score_thr=show_score_thr)
                # Visualize the results of MMDetection model
                # 'show_result' is MMdetection visualization API
                else:
                    batch_size = len(result)
                    if batch_size == 1 and isinstance(data['img'][0],
                                                      torch.Tensor):
                        img_tensor = data['img'][0]
                    else:
                        img_tensor = data['img'][0].data[0]
                    img_metas = data['img_metas'][0].data[0]
                    imgs = tensor2imgs(img_tensor,
                                       **img_metas[0]['img_norm_cfg'])
                    assert len(imgs) == len(img_metas)

                    for i, (img, img_meta) in enumerate(zip(imgs, img_metas)):
                        # crop the padded region, then resize back to the
                        # original image size before drawing boxes
                        h, w, _ = img_meta['img_shape']
                        img_show = img[:h, :w, :]

                        ori_h, ori_w = img_meta['ori_shape'][:-1]
                        img_show = mmcv.imresize(img_show, (ori_w, ori_h))

                        if out_dir:
                            out_file = osp.join(out_dir,
                                                img_meta['ori_filename'])
                        else:
                            out_file = None

                        model.module.show_result(
                            img_show,
                            result[i],
                            show=show,
                            out_file=out_file,
                            score_thr=show_score_thr)
            results.extend(result)

            # backend models always run with batch size 1 here, so advance
            # the progress bar once per dataloader iteration
            prog_bar.update()
        return results
